{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Network In Network (NIN)\n",
    "\n",
    "\n",
    "[Network in Network（NiN）](https://arxiv.org/abs/1312.4400) 的架构是紧跟着AlexNet的，但是这篇论文中有两个想法，值得我们去吸收和借鉴。\n",
    "\n",
    "\n",
    "1、使用了1×1的卷积，来对通道层做全链接操作，并且共享权重，这样的全连接操作，让我们可以保留空间信息。 【这样的做法在NIN中称为 mlpconv】这样的做法，还让以后的提出的模型，可以借助 1×1的卷积结构随意对 channel 进行升降。\n",
    "\n",
    "2、NiN在最后不是使用全连接，而是使用通道数为输出类别个数的 mlpconv ，外接一个平均池化层来将每个通道里的数值平均成一个标量。\n",
    "\n",
    "\n",
    "加一个自己的想法，这种 1×1的卷积 在我看来还可以**促进不同通道之间的信息交换**。\n",
    "\n",
    "![mlpconv](./images/nin.png)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.autograd import Variable\n",
    "from torch.utils.data import DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 超参数类，用于控制各种超参数\n",
    "class Config(object):\n",
    "    def __init__(self):\n",
    "        self.lr = 0.005\n",
    "        self.batch_size = 256\n",
    "        self.use_gpu = torch.cuda.is_available()\n",
    "        self.DOWNLOAD = True\n",
    "        self.epoch_num = 5 # 因为只是demo，就跑了2个epoch，可以自己多加几次试试结果\n",
    "        self.class_num = 10 # CIFAR10 共有10类\n",
    "config = Config()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# NiN提出只对通道层做全连接并且像素之间共享权重来解决上述两个问题\n",
    "# 这种“一卷卷到底”最后加一个平均池化层的做法也成为了深度卷积神经网络的常用设计。\n",
    "def mlpconv(in_chanels, out_chanels, kernel_size, padding, strides=1, max_pooling=True):\n",
    "    layers = []\n",
    "    layers += [nn.Conv2d(in_chanels, out_chanels, kernel_size=kernel_size, padding=padding, stride=strides), nn.ReLU(inplace=True)]\n",
    "    layers += [nn.Conv2d(out_chanels, out_chanels, kernel_size=1, padding=0, stride=1), nn.ReLU(inplace=True)]\n",
    "    layers += [nn.Conv2d(out_chanels, out_chanels, kernel_size=1, padding=0, stride=1), nn.ReLU(inplace=True)]\n",
    "    if max_pooling:\n",
    "        layers += [nn.MaxPool2d(kernel_size=3, stride=2)]\n",
    "    return nn.Sequential(*layers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class NIN(nn.Module):\n",
    "    \"\"\"\n",
    "       输入图片的尺寸一定得是 224 × 224 的\n",
    "    \"\"\"\n",
    "    def __init__(self, class_num):\n",
    "        super(NIN, self).__init__()\n",
    "        self.net = nn.Sequential(\n",
    "            mlpconv(3, 96, 11, 0, strides=4),\n",
    "            mlpconv(96, 256, 5, 2),\n",
    "            mlpconv(256, 384, 3, 1),\n",
    "            nn.Dropout(0.5),\n",
    "            # 目标类为10类\n",
    "            mlpconv(384, 10, 3, 1, max_pooling=False),\n",
    "            # 输入为 batch_size x 10 x 5 x 5, 通过AvgPool2D转成\n",
    "            # batch_size x 10 x 1 x 1。\n",
    "            nn.AvgPool2d(kernel_size=5, stride=1)\n",
    "        )\n",
    "        self.class_num = class_num\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = self.net(x)\n",
    "        out = out.view(-1, self.class_num )\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 图像预处理，因为NIN是使用224 * 224大小的图片，但是MNIST只有32 * 32\n",
    "transform = transforms.Compose([\n",
    "    transforms.Resize(224), # 缩放到 224 * 224 大小\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 归一化\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 下载 CIFAR10 dataset\n",
    "train_dataset = torchvision.datasets.CIFAR10(root='./data/', train=True, transform=transform, download=config.DOWNLOAD)\n",
    "test_dataset = torchvision.datasets.CIFAR10(root='./data/', train=False, transform=transform)\n",
    "\n",
    "# dataloader\n",
    "\n",
    "train_loader = DataLoader(dataset=train_dataset,\n",
    "                          batch_size=config.batch_size,\n",
    "                          shuffle=True)\n",
    "test_loader = DataLoader(dataset=test_dataset,\n",
    "                         batch_size=config.batch_size,\n",
    "                         shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "nin = NIN(config.class_num)\n",
    "\n",
    "# 是否使用GPU\n",
    "if config.use_gpu:\n",
    "    nin = nin.cuda()\n",
    "\n",
    "# loss and optimizer\n",
    "loss_fn = nn.CrossEntropyLoss()\n",
    "\n",
    "optimizer = torch.optim.Adam(nin.parameters(), lr=config.lr)\n",
    "\n",
    "for epoch in range(config.epoch_num):\n",
    "    total = 0\n",
    "    correct = 0\n",
    "    for i, (images, labels) in enumerate(train_loader):\n",
    "        images = Variable(images)\n",
    "        labels = Variable(labels)\n",
    "\n",
    "        if config.use_gpu:\n",
    "            images = images.cuda()\n",
    "            labels = labels.cuda()\n",
    "\n",
    "        # forward + backward + optimize\n",
    "        optimizer.zero_grad()\n",
    "        y_pred = nin(images)\n",
    "\n",
    "        loss = loss_fn(y_pred, labels)\n",
    "\n",
    "        loss.backward()\n",
    "\n",
    "        optimizer.step()\n",
    "        \n",
    "        if (i + 1) % 100 == 0:\n",
    "            print(\"Epoch [%d/%d], Iter [%d/%d] Loss: %.4f\" % (epoch + 1, config.epoch_num, i + 1, 100, loss.data[0]))\n",
    "\n",
    "        # 计算训练精确度\n",
    "        _, predicted = torch.max(y_pred.data, 1)\n",
    "        total += labels.size(0)\n",
    "        correct += (predicted == labels.data).sum()\n",
    "    # 结束一次迭代\n",
    "    print('Accuracy of the model on the train images: %d %%' % (100 * correct / total))\n",
    " \n",
    "    # Decaying Learning Rate\n",
    "    if (epoch+1) % 2 == 0:\n",
    "        config.lr /= 3\n",
    "        optimizer = torch.optim.Adam(nin.parameters(), lr=config.lr)\n",
    "\n",
    "\n",
    "# Test\n",
    "nin.eval()\n",
    "\n",
    "correct = 0\n",
    "total = 0\n",
    "\n",
    "for images, labels in test_loader:\n",
    "    images = Variable(images)\n",
    "    labels = Variable(labels)\n",
    "    if config.use_gpu:\n",
    "        images = images.cuda()\n",
    "        labels = labels.cuda()\n",
    "    y_pred = nin(images)\n",
    "    _, predicted = torch.max(y_pred.data, 1)\n",
    "    total += labels.size(0)\n",
    "    temp = (predicted == labels.data).sum()\n",
    "    correct += temp\n",
    "\n",
    "\n",
    "print('Accuracy of the model on the test images: %d %%' % (100 * correct / total))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "\n",
    "在我自己电脑上跑出来的结果是：\n",
    "\n",
    "```\n",
    "Files already downloaded and verified\n",
    "Epoch [1/5], Iter [100/100] Loss: 2.2142\n",
    "Accuracy of the model on the train images: 15 %\n",
    "Epoch [2/5], Iter [100/100] Loss: 2.0748\n",
    "Accuracy of the model on the train images: 23 %\n",
    "Epoch [3/5], Iter [100/100] Loss: 1.8296\n",
    "Accuracy of the model on the train images: 33 %\n",
    "Epoch [4/5], Iter [100/100] Loss: 1.7573\n",
    "Accuracy of the model on the train images: 37 %\n",
    "Epoch [5/5], Iter [100/100] Loss: 1.6143\n",
    "Accuracy of the model on the train images: 42 %\n",
    "Accuracy of the model on the test images: 43 %\n",
    "\n",
    "```\n",
    "\n",
    "看上去似乎还没Alexnet强，其实正常...因为我们模型的参数便多了，但是训练次数却变多，很多参数得不到充分的训练"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
